package topdown;

import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.junit.Test;
import org.junit.Before;
import loader.relationalAlgebraBaseVisitor;
import loader.relationalAlgebraLexer;
import loader.relationalAlgebraParser;
import topdown.concrete_operator.ConcreteOperator;
import topdown.operator.Relation;
import topdown.data_structures.Tuple;
import topdown.operator.*;
import topdown.term.FunctionalSymbol;
import topdown.term.Term;

import java.util.HashSet;
import java.util.Set;

import static org.junit.Assert.assertEquals;

/**
 * Tests for the relational-algebra operators.
 *
 * <p>Each test parses a textual query with the ANTLR-generated {@code relationalAlgebraParser},
 * turns it into an {@link Operator} via the visitor, and compares the operator's output with a
 * hand-written expected {@link Relation}.
 *
 * <p>NOTE(review): names prefixed with {@code $} ({@code $r}, {@code $s}, {@code $t}, {@code $u},
 * {@code $v}) appear to be base relations resolved by the visitor. Their contents are not visible
 * in this file, so the expected results below depend on that data — confirm against the visitor.
 */
public class TestOperator {
    private relationalAlgebraBaseVisitor<Object> visitor;
    /** Expected result of the current test. */
    private Relation res;
    /** Operator under test. */
    private Operator p;

    @Before
    public void setUp() {
        visitor = new relationalAlgebraBaseVisitor<>();
    }

    /**
     * Parses {@code input} as a relational-algebra query and builds the corresponding operator.
     *
     * @param input query text, e.g. {@code "JOIN([$r, $s], [[A, B], [C, B]], [A, B, C])"}
     * @return the operator produced by the visitor for the parsed query
     */
    public Operator loadOperator(String input){
        CharStream inputStream = CharStreams.fromString(input);
        relationalAlgebraLexer operatorLexer = new relationalAlgebraLexer(inputStream);
        CommonTokenStream commonTokenStream = new CommonTokenStream(operatorLexer);
        relationalAlgebraParser operatorParser = new relationalAlgebraParser(commonTokenStream);
        relationalAlgebraParser.QueryContext queryContext = operatorParser.query();
        return (Operator) visitor.visitQuery(queryContext);
    }

    /**
     * Evaluates the operator under test ({@link #p}) and the expected relation ({@link #res}).
     * Prints every tuple the operator produces, then asserts that both yield the same set of
     * tuples.
     *
     * <p>The two iterators are drained independently, so both missing and extra tuples are
     * detected. The comparison is set-based: tuple order and duplicates are ignored.
     *
     * @param msg label printed before the results and used as the assertion message
     */
    private void compareAndPrintResults(String msg) {
        System.out.println("\n" + msg);

        ConcreteOperator actualIt = p.instance();
        ConcreteOperator expectedIt = res.instance();

        Set<Tuple> actual = drain(actualIt, true);
        Set<Tuple> expected = drain(expectedIt, false);

        assertEquals(msg, expected, actual);
    }

    /**
     * Collects every tuple from {@code it} into a set, optionally echoing each one to stdout.
     * A {@code null} returned by {@code next()} marks the end of the stream.
     */
    private static Set<Tuple> drain(ConcreteOperator it, boolean print) {
        Set<Tuple> tuples = new HashSet<>();
        for (Tuple t = it.next(); t != null; t = it.next()) {
            if (print) {
                System.out.println(t);
            }
            tuples.add(t);
        }
        return tuples;
    }

    @Test
    public void testJoin() {
        /* JOIN([$r, $s], [[A, B], [C, B]], [A, B, C])*/
        /* p(A, B, C) <- r(A, B), s(C, B). */
        p = loadOperator("JOIN([$r, $s], [[A, B], [C, B]], [A, B, C])");

        res = new Relation(new Term[][]{
                {new FunctionalSymbol("a"), new FunctionalSymbol("x"), new FunctionalSymbol("y")},
                {new FunctionalSymbol("c"), new FunctionalSymbol("x"), new FunctionalSymbol("y")},
                {new FunctionalSymbol("a"), new FunctionalSymbol("x"), new FunctionalSymbol("a")},
                {new FunctionalSymbol("c"), new FunctionalSymbol("x"), new FunctionalSymbol("a")},
                {new FunctionalSymbol("b"), new FunctionalSymbol("y"), new FunctionalSymbol("d")},
        });
        compareAndPrintResults("Test 1 - p(A, B, C) = Join(r(A, B), s(C, B)):");
    }

    @Test
    public void testProjectionUsingJoin() {
        /* p(A) <- r(A, B). */
        p = loadOperator("JOIN([$r], [[A, B]], [A])");

        res = new Relation(new Term[][]{
                {new FunctionalSymbol("a")},
                {new FunctionalSymbol("b")},
                {new FunctionalSymbol("c")}
        });

        compareAndPrintResults("\nTest 2 - p(A) = Join(r(A, B)):");
    }

    @Test
    public void testAntiJoin() {
        /* p(B, A) <- r(A, B), ~s(B, C). */
        p = loadOperator("ANTIJOIN([$r, $s], [[A, B], [B, C]], [B, A])");

        res = new Relation(new Term[][]{
                {new FunctionalSymbol("x"), new FunctionalSymbol("a")},
                {new FunctionalSymbol("x"), new FunctionalSymbol("c")},
        });

        compareAndPrintResults("\nTest 3 - p(B, A) = AntiJoin(r(A, B), s(B, C))");
    }

    @Test
    public void testAntiJoinEmpty() {
        /* p(A) <- r(A, B), ~t(). */
        p = loadOperator("ANTIJOIN([$r, $t], [[A, B], []], [A])");

        res = new Relation(new Term[][]{
                {new FunctionalSymbol("a")},
                {new FunctionalSymbol("b")},
                {new FunctionalSymbol("c")}
        });

        // Label names relation t, matching the query and the rule comment above.
        compareAndPrintResults("\nTest 4 - p(A) = AntiJoin(r(A, B), t()):");
    }

    @Test
    public void testJoinAntiJoin() {
        /* p(A, B) <- r(A, B), s(A, B), ~s(B, C). */
        p = loadOperator("ANTIJOIN(" +
                "[JOIN([$r, $s], [[A, B], [A, B]], [A, B]), $s], " +
                "[[A, B], [B, C]], " +
                "[A, B])"
        );

        res = new Relation(new Term[][] {
                {new FunctionalSymbol("a"), new FunctionalSymbol("x")}
        });

        compareAndPrintResults("\nTest 5 - p(A, B) = AntiJoin(Join(r(A, B), s(A, B)), s(B, C)):");
    }

    @Test
    public void testUnion() {
        /* p(A, B) <- r(A, B).
           p(A, B) <- s(A, B). */
        p = loadOperator("UNION([$r, $s])");

        res = new Relation(new Term[][] {
                {new FunctionalSymbol("a"), new FunctionalSymbol("x")},
                {new FunctionalSymbol("b"), new FunctionalSymbol("y")},
                {new FunctionalSymbol("c"), new FunctionalSymbol("x")},
                {new FunctionalSymbol("y"), new FunctionalSymbol("x")},
                {new FunctionalSymbol("d"), new FunctionalSymbol("y")},
        });

        compareAndPrintResults("\nTest 6 - p = Union(r, s):");
    }

    @Test
    public void testAncestorFixed() {
        /* p(A, B) <- u(A, B).
           p(A, B) <- u(A, C), p(C, B).
           The recursion is unrolled by hand to a fixed depth instead of using REC. */
        p = loadOperator("UNION(" +
                "[$u, JOIN(" +
                "[$u, UNION(" +
                "[$u, JOIN([$u, $u], [[A, B], [A, B]], [A, B])]" +
                ")], [[A, C], [C, B]]," +
                "[A, B])])"
        );

        res = new Relation(new Term[][] {
                {new FunctionalSymbol("a"), new FunctionalSymbol("b")},
                {new FunctionalSymbol("b"), new FunctionalSymbol("c")},
                {new FunctionalSymbol("a"), new FunctionalSymbol("c")},
        });

        compareAndPrintResults("\nTest 7 - ancestor:");
    }

    @Test
    public void testRecursiveAncestor() {
        /* p(A, B) <- u(A, B).
           p(A, B) <- u(A, C), p(C, B). */
        p = loadOperator("REC(p, UNION([$u, JOIN([$u, p], [[A, C], [C, B]], [A, B])]))");

        res = new Relation(new Term[][] {
                {new FunctionalSymbol("a"), new FunctionalSymbol("b")},
                {new FunctionalSymbol("b"), new FunctionalSymbol("c")},
                {new FunctionalSymbol("a"), new FunctionalSymbol("c")},
        });

        compareAndPrintResults("\nTest 10 - recursive ancestor:");
    }

    @Test
    public void testSmallFixPoint() {
        // Register intermediate operators by name so later queries can reference them.
        Operator p1 = loadOperator("JOIN([$v, $v], [[A], [B]], [A, B])");
        visitor.addRelation("p1", p1);

        Operator p2 = loadOperator("JOIN([p1, $v], [[A, B], [B]], [A, B])");
        visitor.addRelation("p2", p2);

        Operator p3 = loadOperator("JOIN([p2, $v], [[A, B], [B]], [A, B])");
        visitor.addRelation("p3", p3);

        p = loadOperator("REC(p, JOIN([p1, p2, p3], [[A, B], [A, B], [A, B]], [A, B]))");

        res = new Relation(new Term[][] {
                {new FunctionalSymbol("a"), new FunctionalSymbol("b")},
                {new FunctionalSymbol("a"), new FunctionalSymbol("a")},
                {new FunctionalSymbol("b"), new FunctionalSymbol("a")},
                {new FunctionalSymbol("b"), new FunctionalSymbol("b")},
        });

        compareAndPrintResults("Test 8 - small index test:");
    }

    @Test
    public void testMultiAntiJoin() {
        /* p(A, B) <- r(A, B), ~v(A), ~u(B, A). */
        p = loadOperator("ANTIJOIN([$r, $v, $u], [[A, B], [A], [B, A]], [A, B])");

        res = new Relation(new Term[][] {
                {new FunctionalSymbol("c"), new FunctionalSymbol("x")}
        });

        compareAndPrintResults("Test 9 - p(A, B) = AntiJoin(r(A, B), v(A), u(B, A)):");
    }
}
